# Script: solves the appendix model for a grid of inter-country trade costs, in
# parallel across worker processes, and writes the firm-level results to CSV.

using Pkg
# NOTE(review): this (re)installs every listed package on each run, which needs
# network access and is slow; consider a Project.toml + Pkg.instantiate() instead.
Pkg.add(["Distributed", "Distributions", "LinearAlgebra", "SpecialFunctions", "GSL", "DataFrames", "CSV"])


# Region-ordering convention used throughout: subnational regions index rows and
# countries index columns. When iterating over regions, visit every subnational
# region of the first country before moving on to the next country. The rows and
# columns of tau_mat follow the same ordering.

using Distributed

# Start 3 worker processes if only the master process is running.
if nprocs() == 1
    addprocs(3)
end


# Load the packages on the master and on every worker. This must come after
# addprocs so that the new workers receive the definitions too.
@everywhere using Distributed, Distributions, LinearAlgebra, SpecialFunctions, GSL, DataFrames, CSV

# Model switches, defined on every process.
# symmetric_intra: if false, the first region would have a different intra-regional
# trade cost from the rest, which would have an intra-regional trade cost of 1.
# NOTE(review): symmetric_intra and symmetric_distribution are not referenced
# anywhere in the code visible in this file.
@everywhere symmetric_intra = true
@everywhere symmetric_distribution = true
# IO_flag: when true, match acceptance probabilities are fixed at a_io inside
# intra_solve, and the output file name gets an "IO_" prefix.
@everywhere IO_flag = false

# NOTE(review): hard-coded, machine-specific working directory. The CSV below is
# written relative to it.
cd("$(homedir())/Documents/Personal Research/JAERE Extension/Draft/Code/appendix_code")

# Inter-country iceberg trade costs to solve for (9 points from 1.0 to 1.2).
# intra_solve is called once per value.
inter_trade_costs = range(1.0, stop = 1.2, length = 9)

@everywhere function intra_solve(iceberg, IO_flag)
    # Empty accumulator for the results of every intra-regional trade cost solved
    # below. Rows are added with `append!(output_df, f_output)` at the end of each
    # loop iteration, so the column names here must match those of `f_output`
    # (DataFrames matches appended columns by name). This column order is the
    # column order of the CSV written by the script.
    output_df = DataFrame(
        phi = Float64[],
        del = Float64[],
        reg = Float64[],
        country = Int64[],
        subnat = Int64[],
        labor = Float64[],
        wages = Float64[], 
        P_H = Float64[],
        Phi_network = Float64[],
        Del_network = Float64[],
        kappa = Float64[],
        dirty_input = Float64[],
        emissions = Float64[],
        ttl_output = Float64[],
        FA = Float64[],
        e_intensity = Float64[],
        inter_trade_cost = Float64[],
        intra_trade_cost = Float64[],
        emissions_tax = Float64[],
        Del_H = Float64[],
        mass_cus = Float64[],
        mass_sup = Float64[],
        int_sold_q = Float64[],
        int_purch_q = Float64[],
        int_sold_val = Float64[],
        int_purch_val = Float64[],
        g_chi = Float64[])

    # Iceberg costs between subnational regions of the same country. The model is
    # solved once per value; currently a single value is used.
    intra_trade_costs = [1.2] 
    
    for intra in intra_trade_costs
        # ---- Structural parameters ----
        a_lab = 0.24        # labor-allocation parameter (enters the exponent in labor_clearing)
        b_lab = 0.8         # labor-allocation parameter (enters the exponent in labor_clearing)
        t = 0.134           # emissions tax (reported in the `emissions_tax` column)
        A = 1.0             # scale on relationship profits in PhiDelIter
        w = 1.0             # price of dirty input
        eps = 0.000094       # scale of emissions generation (NOTE: shadows Base.eps in this function)

        b = 0             # denominator addition for abatement: emissions scale with 1 / (b + FA)
        max_kap = Inf # 10.0         # maximum unit price of dirty input (cap applied to kap)

        # Open question (original note): who gets the tax revenue? 
        sig = 4.0       # elasticity of substitution
        # Both branches currently set the same input suitability alp = 1.
        if IO_flag == true
            alp = 1.
        else
            alp = 1.        # input suitability
        end
        a_io = 0.1          # fixed match acceptance probability used when IO_flag is true
        m_phi = 0.0          # mean of fundamental productivity
        m_del = 0.          # mean of fundamental quality
        v_phi = 1.          # std of fundamental productivity
        v_del = 1.         # std of fundamental quality
        v_cor = 0.          # correlation between phi and del
        
        m_psi = 0.216       # relationship fixed cost: the default psi returns this constant
        m_psi_F = 0.6       # NOTE(review): not referenced in the code visible here

        s_psi = 0.          # elasticity of psi function (stored only; default psi is constant)
        s_xi  = 0.9565       # Weibull shape of match-cost draws xi (governs xi variance)
        nu = 0.              #.1641          # relationship stickiness
        D = 1.              # size of dirty input
        L_bar = 10.             # Labor supply of each country, split across its subnational regions
        gam = 1.            # fundamental labor productivity assumed constant accross firms
        bet = 0.            # discount factor
        mu = sig / (sig-1)    # markup


        # Bundle of structural parameters. Within this function only the optional
        # overrides "A_phi", "A_del", "psi", "tau" and "g_chi" are read back; the
        # other entries are recorded for reference.
        struc_params = Dict()

        struc_params["sig"] = sig
        struc_params["alp"] = alp
        struc_params["m_phi"] = m_phi
        struc_params["m_del"] = m_del                    
        struc_params["v_phi"] = v_phi
        struc_params["v_del"] = v_del
        struc_params["v_cor"] = v_cor
        struc_params["m_psi"] = m_psi
        struc_params["s_psi"] = s_psi
        struc_params["s_xi"] = s_xi
        struc_params["nu"] = nu
        struc_params["D"] = D
        struc_params["bet"] = bet
        struc_params["mu"] = mu
        # struc_params["m_phi_alt"] = m_phi_alt
        # struc_params["m_del_alt"] = m_del_alt                    
        # struc_params["v_phi_alt"] = v_phi_alt
        # struc_params["v_del_alt"] = v_del_alt
        # struc_params["v_cor_alt"] = v_cor_alt


        # Mean of the match-cost draws xi. The Weibull scale is chosen so that
        # E[xi] = lam_xi * gamma(1 + 1/s_xi) = m_xi.
        m_xi = 1.0
        lam_xi = m_xi / gamma(1 + 1 / s_xi) # CHECK LIM CODE

        # ---- Solver settings ----
        qt_min = 0.05         # min quantile for phi/del distribution
        qt_max = 0.95        # max quantile for phi/del distribution
    
        n_grid = 15        # grid size (points along each of the phi and del dimensions)
        tol = 0.001          # numerical tolerance for iterations
        max_iter = 200      # max iterations (loops test `<= max_iter`, so up to max_iter + 1 passes)
        ondisp = 0          # iteration display toggle (NOTE(review): not referenced in visible code)
        update = 0.9        # damping: weight on the new iterate when updating Del_H and FA

        # Only "n_grid" is read back below. The keys "phi", "del", "Del_H_guess",
        # "Phi_guess" and "Del_guess" are optional overrides, checked with haskey.
        solv_params = Dict( "qt_min" => qt_min,
                            "qt_max" => qt_max,
                            "n_grid" => n_grid,
                            "tol" => tol,
                            "max_iter" => max_iter,
                            "ondisp" => ondisp)

        n_grid_phi = solv_params["n_grid"]
        n_grid_del = solv_params["n_grid"]
        n_grid_subnat = 2                   # subnational regions per country
        n_grid_reg = 4 #solv_params["n_grid"]   # total regions across all countries
        # Int() throws an InexactError unless n_grid_reg is a multiple of n_grid_subnat.
        n_grid_cntry = Int(n_grid_reg/n_grid_subnat)

        # Productivity/quality shifters on the (phi, del, reg) grid. They default to
        # all ones (and the default is recorded in struc_params). A scalar override
        # stored in struc_params is expanded to the full grid.
        grid_shape = (n_grid_phi, n_grid_del, n_grid_reg)

        if !haskey(struc_params, "A_phi")
            A_phi = ones(grid_shape)
            struc_params["A_phi"] = A_phi
        else
            A_phi = struc_params["A_phi"]
            if length(A_phi) == 1
                A_phi = A_phi * ones(grid_shape)
            end
        end

        if !haskey(struc_params, "A_del")
            A_del = ones(grid_shape)
            struc_params["A_del"] = A_del
        else
            A_del = struc_params["A_del"]
            if length(A_del) == 1
                A_del = A_del * ones(grid_shape)
            end
        end

        # Relationship fixed cost psi(phi, del). It is the constant m_psi unless a
        # function is supplied in struc_params. Written as an anonymous function to
        # avoid a conditionally defined local method.
        psi = haskey(struc_params, "psi") ? struc_params["psi"] : (phi_coy, del_coy) -> m_psi

        # Iceberg trade-cost matrix tau_mat (origin region = row, destination =
        # column), unless one is supplied in struc_params. Regions are ordered
        # country by country, subnational region within country.
        if haskey(struc_params, "tau")
            tau_mat = struc_params["tau"]
        else
            # Within a country: 1 on the diagonal (no iceberg cost inside a region),
            # `intra` between its subnational regions.
            tau_mat_intra = intra .* ones(n_grid_subnat,n_grid_subnat)        # Could generalize to not be the same between regions 
            tau_mat_intra[diagind(tau_mat_intra)] .= 1  # replace diagonal with 1  
            # Across countries: `intra * iceberg` off the diagonal, and
            # [intra^2 * iceberg, iceberg] on the diagonal (NOT 1). The two-element
            # right-hand side hard-codes n_grid_subnat == 2.
            tau_mat_inter = intra .* iceberg .* ones(n_grid_subnat, n_grid_subnat) 
            tau_mat_inter[diagind(tau_mat_inter)] .= [intra^2 * iceberg, iceberg]  # diagonal becomes [intra^2 * iceberg, iceberg]
            tau_mat = [tau_mat_intra tau_mat_inter;
                        tau_mat_inter tau_mat_intra]
        end
        # Resulting tau_mat with n_grid_subnat = 2 (origin = row, destination = column,
        # inter = iceberg):
        #        ih               ch             if               cf
        # ih     1                intra          intra^2*inter    intra*inter
        # ch     intra            1              intra*inter      inter
        # if     intra^2*inter    intra*inter    1                intra
        # cf     intra*inter      inter          intra            1
        # NOTE(review): labels presumably denote subnational region (i/c) and
        # country (h/f); confirm.


        # Weighted sum of D with weights g, skipping every entry where D .* g is NaN
        # (NaN cells are dropped, not propagated). It is a mean only when g sums to
        # one over the retained cells.
        # renorm == 1: first rescale g to sum to one over the non-NaN entries of D.
        function matmean(D, g, renorm = 0)
            weights = renorm == 1 ? g / sum(g[.!isnan.(D)]) : g
            weighted = D .* weights
            return sum(weighted[.!isnan.(weighted)])
        end

        # Match-cost draws xi ~ Weibull(shape = s_xi, scale = lam_xi).
        # G_xi: CDF of xi, applied elementwise to x.
        G_xi(x) = cdf.(Weibull(s_xi, lam_xi), x)

        # E_xi: partial expectation E[xi; xi <= A], elementwise in A. It equals
        # lam_xi * (Gamma(a) - Gamma_upper(a, (A / lam_xi)^s_xi)) with a = 1 + 1/s_xi,
        # using GSL's (non-normalized) upper incomplete gamma function.
        function E_xi(A)
            a_shape = 1 + 1 / s_xi
            upper_tail = GSL.sf_gamma_inc.(a_shape, (A ./ lam_xi) .^ s_xi)
            return lam_xi .* (gamma(a_shape) .- upper_tail)
        end
# TODO (original note): define phi and del from the alternative (M_alt)
# distribution parameters. Both grids already carry a region dimension
# (phi, del, reg), so region-specific values can be placed per slice.
        # Fundamental productivity grid phi[i, j, k]: log-normal quantiles between
        # qt_min and qt_max along dimension 1, constant along del (j) and region (k).
        if haskey(solv_params, "phi")
            phi = solv_params["phi"]
        else
            phi_levels = quantile.(LogNormal(m_phi, v_phi), range(qt_min, stop = qt_max, length = n_grid_phi))
            phi = [phi_levels[i] for i in 1:n_grid_phi, j in 1:n_grid_del, k in 1:n_grid_reg]
        end


        # Fundamental quality grid del[i, j, k]: log-normal quantiles along
        # dimension 2, constant along phi (i) and region (k).
        if haskey(solv_params, "del")
            del = solv_params["del"]
        else # COULD ALSO USE A symmetric_intra FLAG HERE AND CHANGE qt_max -> trunc_qt
            del_levels = quantile.(LogNormal(m_del, v_del), range(qt_min, stop = qt_max, length = n_grid_del))
            del = [del_levels[j] for i in 1:n_grid_phi, j in 1:n_grid_del, k in 1:n_grid_reg]
        end

        # Probability weights g_chi on the (phi, del, reg) grid: the bivariate
        # log-normal density of (phi, del) at each grid point, normalized so the
        # weights sum to one over all grid points and all regions together.
        # NOTE(review): an earlier "#n_grid_reg *" note hints that per-region
        # normalization was considered; confirm the intended normalization.
        if haskey(struc_params, "g_chi")
            g_chi = struc_params["g_chi"]
        else
            M = [m_phi, m_del]
            V = [v_phi^2 v_cor*v_phi*v_del; v_cor*v_phi*v_del v_del^2]
            # Build the distribution once rather than once per grid point.
            chi_dist = MvLogNormal(M, V)
            # Dimension 2 is the del dimension, so it is sized by n_grid_del.
            g_chi = zeros(n_grid_phi, n_grid_del, n_grid_reg)
            for n_reg in 1:n_grid_reg, n_del in 1:n_grid_del, n_phi in 1:n_grid_phi
                phi_pt = phi[n_phi, n_del, n_reg]
                del_pt = del[n_phi, n_del, n_reg]
                # Grid points with a zero phi or del keep zero weight.
                if phi_pt != 0 && del_pt != 0
                    g_chi[n_phi, n_del, n_reg] = pdf(chi_dist, [phi_pt; del_pt])
                end
            end
            g_chi = g_chi ./ sum(g_chi)
        end

        function labor_clearing(Phi_iter, Del_iter, Del_H_iter)
            # Split each country's labor endowment L_bar across its subnational
            # regions, and compute regional wages and price indices P_H given the
            # network states (Phi_iter, Del_iter) and Del_H_iter.
            #
            # Returns a Dict whose entries "wages", "L", "P_H" and "Xi" are
            # n_grid_reg x 1 matrices. "Xi" is reported after raising it to the
            # allocation exponent.
            scaled_Del = Del_iter .* gam .^ (sig - 1)

            xi_num = zeros(n_grid_reg)
            price_index = zeros(n_grid_reg)
            for r in 1:n_grid_reg
                xi_num[r] = matmean(scaled_Del[:, :, r], g_chi[:, :, r]) ^ (1 / sig)

                # Slice d holds the iceberg cost from origin region d into region r.
                inbound_tau = zeros(n_grid_phi, n_grid_del, n_grid_reg)
                for d in 1:n_grid_reg
                    inbound_tau[:, :, d] .= tau_mat[d, r]
                end
                supply_term = (del ./ inbound_tau) .^ (sig - 1) .* Phi_iter
                price_index[r] = (mu * matmean(supply_term, g_chi)) ^ (1 / (1 - sig))
            end

            # Arrange as (subnational region, country): one column per country.
            xi_by_cntry = reshape(xi_num ./ price_index, n_grid_subnat, n_grid_cntry)
            alloc_exp = a_lab * sig / (sig * b_lab * (1 - a_lab) - a_lab * (sig - 1))
            xi_pow = xi_by_cntry .^ alloc_exp

            # Each country's L_bar is shared across its subnational regions in
            # proportion to xi_pow; the column sums are the per-country denominators.
            labor_alloc = reshape(L_bar .* (xi_pow ./ sum(xi_pow, dims = 1)), n_grid_reg, 1)
            wage_vec = (Del_H_iter ./ labor_alloc) .^ (1 / sig) .* xi_num

            return Dict{Any,Any}(
                "wages" => wage_vec,
                "L" => labor_alloc,
                "P_H" => reshape(price_index, n_grid_reg, 1),
                "Xi" => reshape(xi_pow, n_grid_reg, 1))
        end

        function FA_iter(F_iter, Del_H_iter, Del_iter)
            # One fixed-point update of abatement spending FA on the full grid, given
            # the current FA (F_iter), Del_H and network demand Del. The result is
            # floored at 0.001 so that b + FA, used as a denominator elsewhere,
            # stays strictly positive.
            marginal = t .* eps .* (sig .- 1) .* phi .^ (sig .- 1) .*
                (b .+ F_iter) .^ (sig .- 2) .* (mu - 1) .* Del_H_iter .* Del_iter
            root = marginal .^ (1 / sig)
            return max.((1 ./ w) .* (root .- t .* eps) .- b, 0.001)
        end


        function DelHIter(Phi_sol, Del_sol, Del_H_iter, FA)
            # New guess of Del_H. It is chosen so that the g_chi-weighted mean of
            # Del_H * Del * kap^(-sig) * (phi * A_phi)^(sig - 1) equals D; with
            # A_phi = 1 this is the weighted mean of `dirty_input` computed after
            # the solve. kap is the dirty-input unit price w + t*eps/(b + FA),
            # capped at max_kap. The result is floored at 0.001.
            #
            # Phi_sol and Del_H_iter are accepted for interface compatibility but
            # are not used.
            # NOTE(review): an earlier draft also added labor for relationship fixed
            # costs (D_f) to this update; that term is not part of this version.
            # NOTE(review): if the dirty input becomes region-specific, D and the
            # result should become vectors.
            unit_price = min.(w .+ ((t * eps) ./ (b .+ FA)), max_kap)
            demand_per_unit = matmean(Del_sol .* (unit_price .^ (-sig)) .* (phi .* A_phi) .^ (sig - 1), g_chi, 0)
            return max(D / demand_per_unit, 0.001)
        end

        function PhiDelIter(Phi_Iter, Del_Iter, Del_H_iter, FA)
            # One iteration of the (Phi, Del) functional equation on the full
            # (phi, del, reg) grid, given current network states, Del_H and abatement FA.
            # For each firm:
            #   Phi_new = (phi*A_phi/kap)^(sig-1) + (gam/wage)^(sig-1)
            #             + (alp/mu)^(sig-1) * E[Phi * m_S * tau_S^(1-sig)]
            #   Del_new = mu^(-sig) * (del*A_del)^(sig-1)
            #             + mu^(-sig) * alp^(sig-1) * E[Del * m_C * tau_C^(-sig)]
            # Here E is the g_chi-weighted matmean, and m_S / m_C are supplier /
            # customer match probabilities (fixed at a_io when IO_flag is true).
            # Returns Dict("Phi" => Phi_new, "Del" => Del_new).
            labor_vars = labor_clearing(Phi_Iter, Del_Iter, Del_H_iter)
            # Distinct local names (wage_vec, kap_pt) keep this closure from
            # overwriting same-named variables of the enclosing loop body.
            wage_vec = labor_vars["wages"]

            Phi_new = zeros(n_grid_phi, n_grid_del, n_grid_reg)
            Del_new = zeros(n_grid_phi, n_grid_del, n_grid_reg)

            # Profit scale common to every firm and partner.
            con = A * mu^(-sig) * (mu - 1) * alp^(sig - 1) * Del_H_iter

            for n_reg in 1:n_grid_reg
                # Trade costs depend only on the firm's region, so build them once
                # per region.
                #   tau_C[:, :, d]: firm's region n_reg -> customers in region d
                #   tau_S[:, :, d]: suppliers in region d -> firm's region n_reg
                tau_C = zeros(n_grid_phi, n_grid_del, n_grid_reg)
                tau_S = zeros(n_grid_phi, n_grid_del, n_grid_reg)
                for d_reg in 1:n_grid_reg
                    tau_C[:, :, d_reg] .= tau_mat[n_reg, d_reg]
                    tau_S[:, :, d_reg] .= tau_mat[d_reg, n_reg]
                end
                tau_S_pow = tau_S .^ (1 - sig)
                tau_C_neg = tau_C .^ (-sig)
                pi_S_base = con .* tau_S_pow .* Phi_Iter
                pi_C_base = con .* tau_C .^ (1 - sig)
                labor_term = (gam ./ wage_vec[n_reg]) .^ (sig - 1)

                for n_del in 1:n_grid_del, n_phi in 1:n_grid_phi
                    # Effective unit price of the dirty input, capped at max_kap.
                    kap_pt = min(w + eps * t / (b + FA[n_phi, n_del, n_reg]), max_kap)

                    # Profits generated with every potential supplier and customer.
                    pi_S = pi_S_base .* Del_Iter[n_phi, n_del, n_reg]
                    pi_C = pi_C_base .* Phi_Iter[n_phi, n_del, n_reg] .* Del_Iter

                    # Relationship fixed cost of this firm. Index the grid point
                    # fully: phi/del are 3-D, so linear indexing would pick the
                    # wrong cell.
                    psi_mat = psi(phi[n_phi, n_del, n_reg], del[n_phi, n_del, n_reg])
                    xi_max_S = max.((pi_S ./ psi_mat .- bet * nu) ./ (1 .- bet * nu), 0)
                    xi_max_C = max.((pi_C ./ psi_mat .- bet * nu) ./ (1 .- bet * nu), 0)

                    # Acceptance probabilities double as matching probabilities.
                    m_S = G_xi(xi_max_S)
                    m_C = G_xi(xi_max_C)
                    if IO_flag == true
                        m_S .= a_io
                        m_C .= a_io
                    end

                    Phi_new[n_phi, n_del, n_reg] =
                        (phi[n_phi, n_del, n_reg] * A_phi[n_phi, n_del, n_reg] / kap_pt)^(sig - 1) +
                        labor_term +
                        (alp / mu)^(sig - 1) * matmean(Phi_Iter .* m_S .* tau_S_pow, g_chi, 0)
                    Del_new[n_phi, n_del, n_reg] =
                        mu^(-sig) * (del[n_phi, n_del, n_reg] * A_del[n_phi, n_del, n_reg])^(sig - 1) +
                        mu^(-sig) * alp^(sig - 1) * matmean(Del_Iter .* m_C .* tau_C_neg, g_chi, 0)
                end
            end

            return Dict{Any,Any}("Phi" => Phi_new, "Del" => Del_new)
        end


        # Starting values for the fixed-point iteration; solv_params may override them.
        # Del_H defaults to its empty-network value:
        #   mu^sig * D / mean((phi*A_phi)^(sig-1) * (del*A_del)^(sig-1)),
        # where mean is the unweighted mean over all grid cells.
        if !haskey(solv_params, "Del_H_guess")
            own_term = (phi .* A_phi) .^ (sig - 1) .* (del .* A_del) .^ (sig - 1)
            Del_H_guess = mu^sig * D / mean(own_term)
        else
            Del_H_guess = solv_params["Del_H_guess"]
        end

        # Phi and Del default to the firms' own terms, (phi*A_phi)^(sig-1) and
        # mu^(-sig) * (del*A_del)^(sig-1). Both must be supplied to override.
        has_network_guess = haskey(solv_params, "Phi_guess") && haskey(solv_params, "Del_guess")
        Phi_guess = has_network_guess ? solv_params["Phi_guess"] : (phi .* A_phi) .^ (sig - 1)
        Del_guess = has_network_guess ? solv_params["Del_guess"] : mu^(-sig) .* (del .* A_del) .^ (sig - 1)

        # Nested fixed-point solve. Results are kept in `out`:
        #   outer loop  - abatement FA (out["F"]), damped with weight `update`
        #   middle loop - Del_H given FA (out["Del_H"]), damped with weight `update`
        #   inner loop  - network states (Phi, Del) given (Del_H, FA)
        # Each loop runs at most max_iter + 1 times. The latest iterates are always
        # written to `out`. A loop that stops at max_iter therefore still leaves
        # out["Phi"], out["Del"] and out["Del_H"] set, instead of raising a KeyError
        # on the next read.
        out = Dict()

        F_guess = 100 * ones(n_grid_phi, n_grid_del, n_grid_reg)
        out["F"] = F_guess
        F_res = Inf
        F_niter = 0

        # Iterate over F
        while F_res > tol && F_niter <= max_iter
            F_niter = F_niter + 1
            FA = out["F"]

            Del_H = Del_H_guess
            Del_H_res = Inf
            Del_H_niter = 0

            while Del_H_res > tol && Del_H_niter <= max_iter
                Del_H_niter = Del_H_niter + 1

                # Cold start from the initial guess on the first pass; afterwards
                # warm start from the last (Phi, Del).
                if Del_H_niter == 1
                    Phi = Phi_guess
                    Del = Del_guess
                else
                    Phi = out["Phi"]
                    Del = out["Del"]
                end

                PhiDel_res = Inf
                PhiDel_iter = 0
                while PhiDel_res > tol && PhiDel_iter <= max_iter
                    PhiDel_iter = PhiDel_iter + 1

                    # Iterate on the (Phi, Del) functional equation.
                    out_PhiDel = PhiDelIter(Phi, Del, Del_H, FA)
                    Phi_new = out_PhiDel["Phi"]
                    Del_new = out_PhiDel["Del"]

                    PhiDel_res = max(maximum(abs.(Phi - Phi_new)), maximum(abs.(Del - Del_new)))

                    # Keep the previous iterate once converged (matches the
                    # convergence test); otherwise advance.
                    if PhiDel_res > tol
                        Phi = Phi_new
                        Del = Del_new
                    end
                end
                # Store the latest (Phi, Del) whether or not the inner loop converged.
                out["Phi"] = Phi
                out["Del"] = Del

                if PhiDel_res > tol
                    println("warning: Phi/Del iteration failed to converge")
                    sleep(1)
                end

                Del_H_new = DelHIter(Phi, Del, Del_H, FA)
                Del_H_res = abs(Del_H - Del_H_new)

                # Damped update of Del_H if not yet converged.
                if Del_H_res > tol
                    Del_H = update .* Del_H_new .+ (1 .- update) .* Del_H
                end
            end
            # Store the latest Del_H whether or not the middle loop converged.
            out["Del_H"] = Del_H

            if Del_H_res > tol
                println("warning: Del_H iteration failed to converge")
                sleep(1)
            end

            # Damped update of abatement FA; the residual uses the undamped proposal.
            F_new = FA_iter(FA, Del_H, out["Del"])
            F_res = maximum(abs.(FA - F_new))
            out["F"] = update .* F_new .+ (1 - update) .* FA
        end

        # The outer loop only exits with F_res > tol after exhausting max_iter.
        if F_res > tol
            println("F DID NOT CONVERGE")
        end

        # Read back the solved equilibrium objects.
        Phi, Del, Del_H, FA = out["Phi"], out["Del"], out["Del_H"], out["F"]

        # Effective unit price of the dirty input, capped at max_kap. It is computed
        # directly into kap_out, so no loop-body variable named `kap` is shared with
        # the solver closures.
        kap_out = min.(w .+ (t * eps) ./ (b .+ FA), max_kap)

        # Firm-level dirty-input use, emissions, total output and emission intensity.
        dirty_input = Del_H .* Del .* kap_out .^ (-sig) .* phi .^ (sig - 1)
        emissions = eps .* dirty_input ./ (b .+ FA)
        ttl_output = Del_H .* Del .* Phi .^ (sig / (sig - 1))
        e_intensity = emissions ./ ttl_output


        function intermediate_inputs(Del_H, Phi, Del, FA)
            # Firm-level network aggregates at the solved equilibrium. For each grid
            # point (phi, del, reg) it recomputes the supplier/customer match
            # probabilities exactly as in the (Phi, Del) iteration, and returns a
            # Dict of n_grid_phi x n_grid_del x n_grid_reg arrays:
            #   "mass_sup", "mass_cus"          g_chi-weighted mass of matched suppliers / customers
            #   "intermediates_sold_q"          quantity of intermediates sold to customers
            #   "intermediates_purchased_q"     quantity of intermediates bought from suppliers
            #   "intermediates_sold_val"        value of intermediates sold
            #   "intermediates_purchased_val"   value of intermediates bought
            # FA is accepted for interface symmetry but not used.
            mass_cus = zeros(n_grid_phi, n_grid_del, n_grid_reg)
            mass_sup = zeros(n_grid_phi, n_grid_del, n_grid_reg)

            intermediates_sold_q = zeros(n_grid_phi, n_grid_del, n_grid_reg)
            intermediates_purchased_q = zeros(n_grid_phi, n_grid_del, n_grid_reg)

            intermediates_sold_val = zeros(n_grid_phi, n_grid_del, n_grid_reg)
            intermediates_purchased_val = zeros(n_grid_phi, n_grid_del, n_grid_reg)

            con = mu^(-sig) * (mu - 1) * alp^(sig - 1) * Del_H
            # Phi^(sig/(sig-1)) does not depend on the grid point; compute it once.
            Phi_q = Phi .^ (sig / (sig - 1))

            for n_reg in 1:n_grid_reg
                # Trade costs depend only on the firm's region.
                #   tau_C[:, :, d]: n_reg -> customers in d; tau_S[:, :, d]: suppliers in d -> n_reg
                tau_C = zeros(n_grid_phi, n_grid_del, n_grid_reg)
                tau_S = zeros(n_grid_phi, n_grid_del, n_grid_reg)
                for d_reg in 1:n_grid_reg
                    tau_C[:, :, d_reg] .= tau_mat[n_reg, d_reg]
                    tau_S[:, :, d_reg] .= tau_mat[d_reg, n_reg]
                end
                tau_C_val = tau_C .^ (1 - sig)
                tau_S_val = tau_S .^ (1 - sig)
                tau_C_q = tau_C .^ (-sig)
                tau_S_q = tau_S .^ (-sig)

                for n_del in 1:n_grid_del, n_phi in 1:n_grid_phi
                    Phi_own = Phi[n_phi, n_del, n_reg]
                    Del_own = Del[n_phi, n_del, n_reg]

                    # Profits generated with every potential supplier and customer.
                    pi_S = con .* tau_S_val .* Phi .* Del_own
                    pi_C = con .* tau_C_val .* Phi_own .* Del

                    psi_mat = psi(phi[n_phi, n_del, n_reg], del[n_phi, n_del, n_reg])
                    xi_max_S = max.((pi_S ./ psi_mat .- bet * nu) ./ (1 .- bet * nu), 0)
                    xi_max_C = max.((pi_C ./ psi_mat .- bet * nu) ./ (1 .- bet * nu), 0)

                    # Acceptance probabilities double as matching probabilities.
                    m_S = G_xi(xi_max_S)
                    m_C = G_xi(xi_max_C)
                    if IO_flag == true
                        m_S .= a_io
                        m_C .= a_io
                    end

                    mass_sup[n_phi, n_del, n_reg] = matmean(m_S, g_chi)
                    mass_cus[n_phi, n_del, n_reg] = matmean(m_C, g_chi)

                    intermediates_sold_q[n_phi, n_del, n_reg] = Del_H * mu^(-sig) *
                        Phi_own^(sig / (sig - 1)) * matmean(m_C .* tau_C_q .* Del, g_chi)
                    intermediates_purchased_q[n_phi, n_del, n_reg] = Del_H * mu^(-sig) *
                        Del_own * matmean(m_S .* tau_S_q .* Phi_q, g_chi)

                    intermediates_sold_val[n_phi, n_del, n_reg] = Del_H * mu^(1 - sig) *
                        Phi_own * matmean(m_C .* tau_C_val .* Del, g_chi)
                    intermediates_purchased_val[n_phi, n_del, n_reg] = Del_H * mu^(1 - sig) *
                        Del_own * matmean(m_S .* tau_S_val .* Phi, g_chi)
                end
            end

            # Each key maps to its own array; the *_val entries hold values, not quantities.
            return Dict{Any,Any}(
                "mass_sup" => mass_sup,
                "mass_cus" => mass_cus,
                "intermediates_sold_q" => intermediates_sold_q,
                "intermediates_purchased_q" => intermediates_purchased_q,
                "intermediates_sold_val" => intermediates_sold_val,
                "intermediates_purchased_val" => intermediates_purchased_val)
        end


        function region_id(n_regions)
            # Region label (as Float64) of every (phi, del, reg) grid cell, in the
            # column-major order produced by `A[:]` on an
            # n_grid_phi x n_grid_del x n_regions array.
            cells_per_region = n_grid_phi * n_grid_del
            return [Float64(r) for r in 1:n_regions for _ in 1:cells_per_region]
        end

        # Equilibrium objects that are only needed for reporting.
        intermediate_info = intermediate_inputs(Del_H, Phi, Del, FA)
        Labor_vars = labor_clearing(Phi, Del, Del_H)

        # Spread a per-region vector over the full (phi, del, reg) grid so it lines
        # up cell by cell with the other 3-D arrays.
        to_grid(v) = repeat(reshape(vec(v), 1, 1, n_grid_reg), n_grid_phi, n_grid_del, 1)
        wage_grid = to_grid(Labor_vars["wages"])
        price_grid = to_grid(Labor_vars["P_H"])

        labor_dem = Del_H .* Del .* wage_grid .^ (-sig) .* gam .^ (sig - 1)

        # Identifier columns follow the flattened grid's column-major order: all
        # subnational regions of country 1 first, then country 2, and so on.
        n_cells = n_grid_phi * n_grid_del * n_grid_reg
        country = repeat(1:n_grid_cntry, inner = n_grid_phi * n_grid_del * n_grid_subnat)
        subnat = repeat(1:n_grid_subnat, inner = n_grid_phi * n_grid_del, outer = n_grid_cntry)

        # One row per grid cell; scalar settings are repeated on every row.
        f_output = DataFrame(
            phi = phi[:],
            del = del[:],
            country = country,
            subnat = subnat,
            reg = region_id(n_grid_reg),
            Phi_network = Phi[:],
            Del_network = Del[:],
            kappa = kap_out[:],
            dirty_input = dirty_input[:],
            emissions = emissions[:],
            ttl_output = ttl_output[:],
            FA = FA[:],
            labor = labor_dem[:],
            wages = wage_grid[:],
            P_H = price_grid[:],
            e_intensity = e_intensity[:],
            inter_trade_cost = iceberg .* ones(n_cells),
            intra_trade_cost = intra .* ones(n_cells),
            emissions_tax = t .* ones(n_cells),
            Del_H = Del_H .* ones(n_cells),
            mass_cus = intermediate_info["mass_cus"][:],
            mass_sup = intermediate_info["mass_sup"][:],
            int_sold_q = intermediate_info["intermediates_sold_q"][:],
            int_purch_q = intermediate_info["intermediates_purchased_q"][:],
            int_sold_val = intermediate_info["intermediates_sold_val"][:],
            int_purch_val = intermediate_info["intermediates_purchased_val"][:],
            g_chi = g_chi[:]
        )
        append!(output_df, f_output)
    end 
    return output_df
end


# Solve the model once per inter-country trade cost, in parallel across the
# worker processes. The per-cost DataFrames are stacked with vcat; with a
# reducer, @distributed blocks until every result is in.
firm_output = @distributed (vcat) for iceberg in inter_trade_costs
    intra_solve(iceberg, IO_flag)
end;

# The input-output variant gets an "IO_" prefix so the two runs do not overwrite
# each other. CSV is already loaded by the @everywhere using line above.
file_name = (IO_flag == true ? "IO_" : "") * "appendix_model.csv"

CSV.write(file_name, firm_output)

